import torch
import numpy as np
from numba import jit
class Z34:
    """Center-based clustering minimising ``sum_i min_c ||x_i - c|| ** z``.

    Each round of ``fit`` assigns points to their nearest center (through the
    numba helper ``reassign_points_faster``) and then re-optimises every
    center independently with AdamW steps on that cluster's cost. Initial
    centers are drawn by D^z sampling: each new center is a data row picked
    with probability proportional to its current cost.
    """

    def __init__(self, k, max_iter, max_sgd, lr, z):
        """
        Args:
            k: number of centers.
            max_iter: maximum number of assign/update rounds in ``fit``.
            max_sgd: maximum AdamW steps per center update.
            lr: AdamW learning rate for center updates.
            z: exponent applied to Euclidean distances in the cost.
        """
        self.k = k
        self.z = z
        self.centers = None
        self.max_iter = max_iter
        self.center_full = None  # best centers found by get_original_cost
        self.centers_subset = None  # centers fitted by get_subset_and_original
        self.history = []  # per-round costs of the most recent fit
        self.max_sgd = max_sgd
        self.lr = lr

    def get_centers_full(self):
        """Return the centers stored by ``get_original_cost``."""
        assert self.center_full is not None
        return self.center_full

    def get_centers_subset(self):
        """Return the centers stored by ``get_subset_and_original``."""
        assert self.centers_subset is not None
        return self.centers_subset

    def get_original_cost(self, data, n_init=10):
        """Fit ``n_init`` times on ``data`` and keep the cheapest solution.

        Args:
            data: numpy array with one point per row; cast to float32.
            n_init: number of independent restarts (default 10).

        Returns:
            The cost of the best centers on ``data``. Those centers are
            stored in ``self.center_full``.

        Raises:
            ValueError: if ``n_init`` is smaller than 1.
        """
        if n_init < 1:
            raise ValueError("n_init must be at least 1")
        data_tensor = torch.from_numpy(data).float()
        best_score = None
        best_centers = None
        for _ in range(n_init):
            centers = self.fit(data_tensor)
            score = self.get_cost_for_centers(data_tensor, centers)
            if best_score is None or score < best_score:
                best_score = score
                best_centers = centers
        self.center_full = best_centers
        # get_cost_for_centers is a deterministic function of (data, centers),
        # so the stored score equals a recomputation on best_centers.
        return best_score

    def get_subset_and_original(self, data_full, data_subset):
        """Fit on ``data_subset`` and evaluate the centers on both datasets.

        Args:
            data_full: numpy array of all points, one per row.
            data_subset: numpy array of the subset points, one per row.

        Returns:
            ``(cost_subset, cost_original)``: the cost of the subset-fitted
            centers on the (float32) subset and on the full data. The centers
            are stored in ``self.centers_subset``.
        """
        data_subset_tensor = torch.from_numpy(data_subset).float()
        centers = self.fit(data_subset_tensor)
        self.centers_subset = centers
        cost_subset = self.get_cost_for_centers(data_subset_tensor, centers)
        # Evaluated directly on the numpy array, without the float32 cast.
        cost_original = total_cost(data_full, centers.numpy(), self.z)
        return cost_subset, cost_original

    def check_convergence(self):
        """Return True once the latest cost is not 1% below any of the previous four.

        At least five recorded costs are required before convergence can be
        declared.
        """
        if len(self.history) < 5:
            return False
        latest = self.history[-1]
        # Compare against history[-5:-1], i.e. the four entries before the
        # latest one; comparing the latest cost with itself is meaningless.
        return not any(latest < 0.99 * prev for prev in self.history[-5:-1])

    def fit(self, data):
        """Alternate nearest-center assignment and per-center AdamW refinement.

        Args:
            data: float tensor with one point per row.

        Returns:
            A ``(k, d)`` tensor of the centers that achieved the lowest cost
            among all evaluated rounds. ``self.history`` receives the cost of
            every round.
        """
        best_cost = float("inf")
        print("getting initial centers")
        centers = self.get_initial_centers(data)
        print("got initial centers")
        best_centers = centers.clone()
        self.history = []
        data_np = data.numpy()
        for _ in range(self.max_iter):
            assert torch.all(torch.isfinite(centers))
            if self.check_convergence():
                print("stopping with", self.history)
                break
            clusters, cost = reassign_points_faster(data_np, centers.numpy(), self.z)
            clusters = torch.from_numpy(clusters)
            if cost < best_cost:
                best_cost = cost
                # Snapshot: `centers` is updated in place below, so keeping a
                # plain reference would silently track the latest centers.
                best_centers = centers.clone()
            self.history.append(cost)
            for j in range(self.k):
                members = torch.where(clusters == j)[0]
                if members.numel() == 0:
                    # Empty cluster: reseed it at a random data point.
                    centers[j] = data[torch.randint(0, data.shape[0], (1,))]
                    continue
                centers[j] = self.find_best_center_iter(data[members], centers[j])
        return best_centers

    def find_best_center_iter(self, data, center):
        """Refine one center with AdamW on ``sum_x ||x - center|| ** z``.

        Args:
            data: points assigned to this center (tensor or array, one per row).
            center: starting position; copied, never modified in place.

        Returns:
            The refined center as a detached tensor.

        Optimisation runs for at most ``self.max_sgd`` steps. It stops early
        when the cost exceeds the cost from three evaluations earlier by more
        than a factor of 1 / 0.98.
        """
        cluster_tensor = torch.as_tensor(data)
        # Optimise a private leaf copy so the caller's tensor (typically a row
        # view of the full centers matrix) is not mutated by the optimiser.
        center = torch.as_tensor(center).detach().clone().requires_grad_(True)
        optimizer = torch.optim.AdamW([center], lr=self.lr)
        history = []
        for _ in range(self.max_sgd):
            optimizer.zero_grad()
            cost = torch.sum(torch.linalg.norm(cluster_tensor - center, dim=1) ** self.z)
            cost_value = cost.item()
            if len(history) > 3 and history[-3] < 0.98 * cost_value:
                break
            # Store plain floats; storing the tensor kept each step's autograd
            # graph alive.
            history.append(cost_value)
            cost.backward()
            optimizer.step()
        return center.detach()

    def get_cost_for_centers(self, data, centers):
        """Return ``sum_i min_c ||data_i - c|| ** z`` for tensor inputs."""
        _, cost = reassign_points_faster(data.numpy(), centers.numpy(), self.z)
        return cost

    def assign_clusters(self, data, centers):
        """Assign each row of ``data`` to its nearest center.

        Returns:
            ``(clusters, dists)``, both ``(n, 1)`` tensors of the default
            float dtype. They hold the nearest center's index and the distance
            to it raised to ``self.z``. Ties go to the lowest index.
        """
        # (n, k) matrix of Euclidean distances from every point to every center.
        dist = torch.linalg.norm(data[:, None, :] - centers[None, :, :], dim=2)
        nearest, index = torch.min(dist, dim=1)
        dtype = torch.get_default_dtype()
        clusters = index.to(dtype).unsqueeze(1)
        dists = (nearest ** self.z).to(dtype).unsqueeze(1)
        return clusters, dists

    def get_initial_centers(self, data):
        """Pick ``self.k`` starting centers by D^z sampling.

        The first center is a uniformly random row of ``data``. Each further
        center is a row drawn with probability proportional to its current
        cost: the distance to the nearest chosen center, raised to ``self.z``.
        If the total cost is exactly zero (every point coincides with a chosen
        center), the draw falls back to uniform instead of dividing by zero.

        Returns:
            A ``(k, d)`` float32 tensor, where ``d = data.shape[1]``.
        """
        n_points = data.shape[0]
        centers = torch.zeros(self.k, data.shape[1])
        centers[0] = data[torch.randint(0, n_points, (1,))]
        data_np = data.numpy()
        for i in range(1, self.k):
            _, costs = reassign_points_faster2(data_np, centers[:i].numpy(), self.z)
            total = np.sum(costs)
            probs = None if total == 0 else costs / total
            centers[i] = data[np.random.choice(n_points, p=probs)]
        return centers


@jit(nopython=True)
def total_cost(data, centers,z):
    """Return ``sum_i (min_c ||data[i] - c||) ** z``.

    ``data`` holds one point per row and ``centers`` one center per row.
    """
    total = 0
    n_points = data.shape[0]
    for row in range(n_points):
        offsets = data[row] - centers
        # Squared Euclidean distance to the closest center.
        nearest_sq = np.min(np.sum(offsets * offsets, axis=1))
        total += np.sqrt(nearest_sq) ** z
    return total

@jit(nopython=True)
def reassign_points_faster2(data, centers,z):
    """Assign every point to its nearest center and return per-point costs.

    Args:
        data: array with one point per row.
        centers: array with one center per row.
        z: exponent applied to the Euclidean distance.

    Returns:
        ``(clusters, cost)``: float64 arrays of length ``data.shape[0]``.
        ``clusters`` holds the nearest center's index (stored as a float) and
        ``cost`` holds the distance to that center raised to ``z``.
    """
    n_points = data.shape[0]
    clusters = np.zeros(n_points)
    cost = np.zeros(n_points)
    for i in range(n_points):
        diff = data[i] - centers
        # Squared distances to all centers, computed once and reused for both
        # the argmin and the cost.
        sq_dists = np.sum(diff * diff, axis=1)
        best = np.argmin(sq_dists)
        clusters[i] = best
        cost[i] = np.sqrt(sq_dists[best]) ** z
    return clusters, cost

@jit(nopython=True)
def reassign_points_faster(data, centers,z):
    """Assign every point to its nearest center and return the total cost.

    Args:
        data: array with one point per row.
        centers: array with one center per row.
        z: exponent applied to the Euclidean distance.

    Returns:
        ``(clusters, cost)``. ``clusters`` is a float64 array holding each
        point's nearest-center index. ``cost`` is the sum over points of the
        distance to that center raised to ``z``.
    """
    n_points = data.shape[0]
    clusters = np.zeros(n_points)
    cost = 0
    for i in range(n_points):
        diff = data[i] - centers
        # Squared distances to all centers, computed once and reused for both
        # the argmin and the cost.
        sq_dists = np.sum(diff * diff, axis=1)
        best = np.argmin(sq_dists)
        clusters[i] = best
        cost += np.sqrt(sq_dists[best]) ** z
    return clusters, cost

@jit(nopython=True)
def assign_clusters_faster(data, centers,z):
    """Assign each point to its nearest center, returning column vectors.

    Mirrors ``Z34.assign_clusters`` for numpy inputs.

    Args:
        data: array with one point per row.
        centers: array with one center per row.
        z: exponent applied to the Euclidean distance.

    Returns:
        ``(clusters, dists)``: two float64 arrays of shape ``(n, 1)``.
        ``clusters`` holds the nearest center's index. ``dists`` holds the
        distance to that center raised to ``z``.
    """
    n_points = data.shape[0]
    # The shape must be a tuple: np.zeros(n, 1) would treat 1 as a dtype.
    dists = np.zeros((n_points, 1))
    clusters = np.zeros((n_points, 1))
    for i in range(n_points):
        diff = data[i] - centers
        # Numba's np.linalg.norm has no `axis` argument, so compute row norms
        # explicitly.
        dist = np.sqrt(np.sum(diff * diff, axis=1))
        best = np.argmin(dist)
        clusters[i, 0] = best
        dists[i, 0] = dist[best] ** z
    return clusters, dists